{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d503406",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -U huggingface_hub datasets qdrant-client seaborn loguru tqdm numpy pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af4a9ca-bee3-433a-a87a-64bc4c55af75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "\n",
    "from datasets import load_dataset\n",
    "from qdrant_client import QdrantClient, models\n",
    "from qdrant_client.models import PointStruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b0843f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc0520d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'train' split with streaming disabled, then display the dataset summary.\n",
    "dataset = load_dataset(\n",
    "    \"nirantk/dbpedia-entities-google-palm-gemini-embedding-001-100K\",\n",
    "    split=\"train\",\n",
    "    streaming=False,\n",
    ")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47e3baef-2746-47ee-a195-1f1ba37b44b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster endpoint: taken from the QDRANT_URL environment variable when set,\n",
    "# otherwise the cluster this notebook was originally written against.\n",
    "QDRANT_URL = os.getenv(\n",
    "    \"QDRANT_URL\",\n",
    "    \"https://a4197291-1236-40e0-bf18-18e8843a05a2.us-east4-0.gcp.cloud.qdrant.io:6333\",\n",
    ")\n",
    "\n",
    "client = QdrantClient(\n",
    "    url=QDRANT_URL,\n",
    "    # The API key is never hard-coded; os.getenv returns None when QDRANT_API_KEY is unset.\n",
    "    api_key=os.getenv(\"QDRANT_API_KEY\"),\n",
    "    timeout=100,\n",
    "    prefer_grpc=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd6a7e3",
   "metadata": {},
   "source": [
    "# Setting up a Collection with Binary Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c8eb96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the Qdrant collection shared by the indexing and search cells below.\n",
    "collection_name = \"gemini-embedding-001\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db5bb6fb-f93e-42c8-b440-d0b771ef8dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# client.recreate_collection(\n",
    "#     collection_name=f\"{collection_name}\",\n",
    "#     vectors_config=models.VectorParams(\n",
    "#         size=768,\n",
    "#         distance=models.Distance.COSINE,\n",
    "#         on_disk=True,\n",
    "#     ),\n",
    "#     optimizers_config=models.OptimizersConfigDiff(\n",
    "#         default_segment_number=5,\n",
    "#         indexing_threshold=0,\n",
    "#     ),\n",
    "#     quantization_config=models.BinaryQuantization(\n",
    "#         binary=models.BinaryQuantizationConfig(always_ram=True),\n",
    "#     ),\n",
    "#     shard_number=2,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60426ff7-52f8-4974-9e6e-56087863a0e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the collection metadata (kept under the same name the notebook already uses).\n",
    "collection_info = client.get_collection(collection_name=collection_name)\n",
    "\n",
    "# One PointStruct per dataset row; the row index doubles as the point id.\n",
    "# Built unconditionally because the search cells below reuse `points`.\n",
    "points = [\n",
    "    PointStruct(\n",
    "        id=row_idx,\n",
    "        vector=row[\"embedding\"],\n",
    "        payload={\"text\": row[\"text\"], \"title\": row[\"title\"]},\n",
    "    )\n",
    "    for row_idx, row in enumerate(dataset)\n",
    "]\n",
    "\n",
    "# Use an exact count for the emptiness check: `CollectionInfo.vectors_count` is\n",
    "# deprecated in newer Qdrant releases and may be None or missing, which would\n",
    "# silently skip indexing.\n",
    "existing_points = client.count(collection_name=collection_name, exact=True).count\n",
    "\n",
    "if existing_points == 0:\n",
    "    print(\"Collection is empty. Begin indexing.\")\n",
    "    bs = 100  # Batch size\n",
    "    for start in tqdm(range(0, len(points), bs)):\n",
    "        client.upsert(\n",
    "            collection_name=collection_name,\n",
    "            points=points[start:start + bs],  # one batch of at most `bs` points\n",
    "        )\n",
    "    # After the bulk upload, set the indexing threshold to 20000.\n",
    "    # The keyword is `optimizers_config` (same as in the creation call above).\n",
    "    client.update_collection(\n",
    "        collection_name=collection_name,\n",
    "        optimizers_config=models.OptimizersConfigDiff(\n",
    "            indexing_threshold=20000\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3640b36-adcb-41f5-9e2c-e3ec74b1d82f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refresh the collection metadata and display how many points it holds.\n",
    "# `points_count` is used because `vectors_count` is deprecated in newer Qdrant releases.\n",
    "collection_info = client.get_collection(collection_name=collection_name)\n",
    "collection_info.points_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac42c816-90c8-4a2b-9358-ed9f40283cc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample query: reuse the vector of points[32] and display its 5 best matches.\n",
    "sample_search_params = models.SearchParams(\n",
    "    exact=True,\n",
    "    quantization=models.QuantizationSearchParams(\n",
    "        oversampling=2.0,\n",
    "        rescore=False,\n",
    "        ignore=False,\n",
    "    ),\n",
    ")\n",
    "client.search(\n",
    "    collection_name=f\"{collection_name}\",\n",
    "    query_vector=points[32].vector,\n",
    "    limit=5,\n",
    "    search_params=sample_search_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c59e5ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled 90/10 split with a fixed seed (37); only the 10% \"test\" part is kept.\n",
    "ds = dataset.train_test_split(\n",
    "    seed=37,\n",
    "    shuffle=True,\n",
    "    test_size=0.1,\n",
    ")[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b35e254",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid axes for the quantization search experiment.\n",
    "rescore_range = [True, False]\n",
    "oversampling_range = np.arange(1.0, 3.1, 1.0)  # yields 1.0, 2.0, 3.0\n",
    "\n",
    "def parameterized_search(\n",
    "        point, \n",
    "        oversampling: float, \n",
    "        rescore: bool, \n",
    "        exact: bool, \n",
    "        collection_name: str, \n",
    "        ignore: bool = False,\n",
    "        limit: int = 10\n",
    "    ):\n",
    "    \"\"\"Search `collection_name` with `point.vector` and return the `client.search` result.\n",
    "\n",
    "    When `exact` is truthy only `exact` is forwarded in the search params, so\n",
    "    `oversampling`, `rescore` and `ignore` have no effect. Otherwise those three\n",
    "    values are passed through as quantization search params.\n",
    "\n",
    "    Uses the notebook-level `client`; at most `limit` hits are requested.\n",
    "    \"\"\"\n",
    "    if exact:\n",
    "        params = models.SearchParams(exact=exact)\n",
    "    else:\n",
    "        quantization = models.QuantizationSearchParams(\n",
    "            oversampling=oversampling,\n",
    "            rescore=rescore,\n",
    "            ignore=ignore,\n",
    "        )\n",
    "        params = models.SearchParams(exact=exact, quantization=quantization)\n",
    "\n",
    "    return client.search(\n",
    "        collection_name=collection_name,\n",
    "        query_vector=point.vector,\n",
    "        search_params=params,\n",
    "        limit=limit,\n",
    "    )\n",
    "\n",
    "import loguru\n",
    "\n",
    "logger = loguru.logger\n",
    "logger.add(\"logs.log\", format=\"{time} {level} {message}\", level=\"INFO\")\n",
    "\n",
    "# Result sizes (top-k) to evaluate; constant, so defined once outside the loops.\n",
    "limit_range = [100, 50, 20, 10, 5, 1]\n",
    "\n",
    "# Each grid combination is appended to results.json as one JSON object per line.\n",
    "with open(\"results.json\", \"w+\") as f:\n",
    "    for point in tqdm(points[10:100]):\n",
    "        # parameterized_search ignores oversampling/rescore when exact=True, so the\n",
    "        # exact ids are cached per limit instead of being re-queried for every\n",
    "        # grid combination of the same point.\n",
    "        exact_ids_by_limit = {}\n",
    "\n",
    "        ## Running Grid Search\n",
    "        for oversampling in oversampling_range:\n",
    "            for rescore in rescore_range:\n",
    "                for limit in limit_range:\n",
    "                    try:\n",
    "                        if limit not in exact_ids_by_limit:\n",
    "                            exact = parameterized_search(point=point, oversampling=oversampling, rescore=rescore, exact=True, collection_name=collection_name, limit=limit)\n",
    "                            exact_ids_by_limit[limit] = [item.id for item in exact]\n",
    "                        hnsw = parameterized_search(point=point, oversampling=oversampling, rescore=rescore, exact=False, collection_name=collection_name, limit=limit)\n",
    "                    except Exception as e:\n",
    "                        # Report only the id: printing the whole point would dump its full vector.\n",
    "                        print(f\"Skipping point: {point.id}\\n{e}\")\n",
    "                        continue\n",
    "\n",
    "                    exact_ids = exact_ids_by_limit[limit]\n",
    "                    hnsw_ids = [item.id for item in hnsw]\n",
    "                    logger.info(f\"Exact: {exact_ids}\")\n",
    "                    logger.info(f\"HNSW: {hnsw_ids}\")\n",
    "\n",
    "                    # Without any exact hits the overlap ratio is undefined; skip the\n",
    "                    # combination instead of dividing by zero.\n",
    "                    if not exact_ids:\n",
    "                        continue\n",
    "\n",
    "                    # Fraction of the exact result ids that the quantized search also returned.\n",
    "                    accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)\n",
    "\n",
    "                    result = {\n",
    "                        \"query_id\": point.id,\n",
    "                        \"oversampling\": oversampling,\n",
    "                        \"rescore\": rescore,\n",
    "                        \"limit\": limit,\n",
    "                        \"accuracy\": accuracy,\n",
    "                    }\n",
    "                    f.write(json.dumps(result))\n",
    "                    f.write(\"\\n\")\n",
    "                    logger.info(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10bf7f19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-imported here so this analysis cell can be run without the cells above.\n",
    "import pandas as pd\n",
    "# results.json holds one JSON object per line, hence lines=True.\n",
    "results = pd.read_json(\"results.json\", lines=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b99eb24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results.to_csv(\"results.csv\", index=False)\n",
    "# Drop the limit=1 and limit=5 rows, then average accuracy per grid combination.\n",
    "average_accuracy = (\n",
    "    results[~results[\"limit\"].isin([1, 5])]\n",
    "    .groupby([\"oversampling\", \"rescore\", \"limit\"])[\"accuracy\"]\n",
    "    .mean()\n",
    "    .reset_index()\n",
    ")\n",
    "# Rows: limit; columns: (oversampling, rescore) pairs; cells: mean accuracy.\n",
    "acc = average_accuracy.pivot(\n",
    "    index=\"limit\",\n",
    "    columns=[\"oversampling\", \"rescore\"],\n",
    "    values=\"accuracy\",\n",
    ")\n",
    "acc"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
